
Fix NotImplemented errors, xformers attention shape, and missing text conditioning#140

Open
Mr-Neutr0n wants to merge 1 commit into Vchitect:main from Mr-Neutr0n:fix/notimplemented-and-attention-bugs

Conversation

@Mr-Neutr0n

Summary

  • raise NotImplemented -> raise NotImplementedError in models/latte.py (line 73) and models/latte_img.py (line 76). NotImplemented is a special singleton used for binary operator fallbacks, not an exception class. Using it with raise produces a TypeError instead of the intended error, masking the real issue.

  • Transpose q, k, v for xformers attention in models/latte_img.py (line 61). After permute(2, 0, 3, 1, 4) and unbind, tensors are shaped (B, heads, N, dim), but xformers.ops.memory_efficient_attention expects (B, N, heads, dim). The implementation in models/latte.py (lines 55-58) correctly transposes before calling xformers; this patch applies the same fix to latte_img.py.

  • Add missing elif self.extras == 78 before final layer in models/latte.py (line 372). The temporal block loop correctly handles extras == 78 by adding text_embedding_temp to the conditioning, but the final adaptive layer norm block only checked for extras == 2 (class conditioning) and fell through to unconditional for all other values. This meant text-conditioned generation (extras == 78) silently dropped text conditioning at the final layer.

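The first fix can be reproduced in isolation. The sketch below is illustrative (the function name and dispatch are hypothetical, loosely mirroring the attention-mode check described above); it shows why `raise NotImplemented` masks the real error:

```python
# Hypothetical minimal repro of the bug described in the first bullet.
# `select_attention` stands in for the attention-mode dispatch; only the
# NotImplemented-vs-NotImplementedError behaviour is the point here.
def select_attention(attention_mode):
    if attention_mode == "math":
        return "plain softmax attention"
    # Correct form: NotImplementedError is an Exception subclass.
    # The buggy form, `raise NotImplemented`, fails at raise time because
    # NotImplemented is the binary-operator fallback singleton, not an
    # exception class.
    raise NotImplementedError(f"attention mode {attention_mode!r} is not supported")

try:
    raise NotImplemented  # what the buggy code effectively did
except TypeError as err:
    # The unsupported-mode message never surfaces; the caller only sees:
    #   TypeError: exceptions must derive from BaseException
    print(type(err).__name__)  # prints "TypeError"
```

This is why the bug is easy to miss: the code path still raises, just with an unrelated `TypeError` instead of the intended message.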
Test plan

  • Verify that NotImplementedError (not a TypeError) is raised when an unsupported attention mode is passed
  • Run inference with xformers attention mode using latte_img.py and confirm no shape errors
  • Run text-conditioned video generation and verify text conditioning is applied through the final layer

… conditioning

- Replace `raise NotImplemented` with `raise NotImplementedError` in both
  latte.py and latte_img.py. `NotImplemented` is not an exception class and
  will raise a TypeError instead of the intended error.

- Transpose q, k, v from (B, heads, N, dim) to (B, N, heads, dim) before
  calling xformers memory_efficient_attention in latte_img.py, matching the
  correct implementation in latte.py. xformers expects the (B, N, heads, dim)
  layout.

- Add missing `elif self.extras == 78` branch before the final layer in
  latte.py so that text_embedding_spatial conditioning is applied during
  the final adaptive layer norm, consistent with the temporal blocks above.
